# Source code for nlp_architect.api.ner_api

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import pickle
from os import makedirs, path, sys

import numpy as np

from nlp_architect.api.abstract_api import AbstractApi
from nlp_architect.models.ner_crf import NERCRF
from nlp_architect import LIBRARY_OUT
from nlp_architect.utils.generic import pad_sentences
from import download_unlicensed_file
from nlp_architect.utils.text import SpacyInstance, bio_to_spans

class NerApi(AbstractApi):
    """
    NER model API.

    Wraps a pre-trained ``NERCRF`` model: makes sure the model files exist on
    disk (downloading them when missing), loads them on request, and tags raw
    text.

    Class attributes:
        model_dir: directory under ``LIBRARY_OUT`` that holds the pre-trained files.
        pretrained_model: path of the model file (``model_v4.h5``).
        pretrained_model_info: path of the pickled model info file
            (``model_info_v4.dat``) holding the vocabularies.
    """

    model_dir = str(LIBRARY_OUT / "ner-pretrained")
    pretrained_model = path.join(model_dir, "model_v4.h5")
    pretrained_model_info = path.join(model_dir, "model_info_v4.dat")

    def __init__(self, prompt=True):
        """
        Ensure the pre-trained files are available and create the tokenizer.

        The model itself is NOT loaded here; call ``load_model()`` before
        ``inference()``.

        Args:
            prompt (bool): when exactly ``True``, ask for confirmation on stdin
                before downloading missing model files.
        """
        self.model = None
        self.model_info = None
        self.word_vocab = None
        self.y_vocab = None
        self.char_vocab = None
        self._download_pretrained_model(prompt)
        self.nlp = SpacyInstance(disable=["tagger", "ner", "parser", "vectors", "textcat"])

    @staticmethod
    def _prompt():
        """
        Ask the user on stdin whether the NER model may be downloaded.

        Returns:
            bool: ``True`` when the (case-insensitive, stripped) answer is
            ``yes`` or ``y``, ``False`` otherwise.
        """
        response = input("\nTo download '{}', please enter YES: ".format("ner"))
        res = response.lower().strip()
        if res in ("yes", "y"):
            print("Downloading {}...".format("ner"))
            return True
        print("Download declined. Response received {} != YES|Y. ".format(res))
        return False

    def _download_pretrained_model(self, prompt=True):
        """
        Download the pre-trained NER model files if they are not on disk.

        Nothing happens when both the model file and the model info file
        already exist.

        Args:
            prompt (bool): when exactly ``True``, ask for confirmation first.

        Raises:
            SystemExit: (exit code 0) when the user declines the download.
        """
        model_exists = path.isfile(self.pretrained_model)
        model_info_exists = path.isfile(self.pretrained_model_info)
        if model_exists and model_info_exists:
            return

        print(
            "The pre-trained models to be downloaded for the NER dataset "
            "are licensed under Apache 2.0. By downloading, you accept the terms "
            "and conditions provided by the license"
        )
        makedirs(self.model_dir, exist_ok=True)
        if prompt is True and not NerApi._prompt():
            sys.exit(0)

        # NOTE(review): the base URL literal below is empty in this copy of the
        # file, so only the "/models/ner/" path remains -- confirm the host
        # prefix against the upstream source.
        download_unlicensed_file(
            "" "/models/ner/",
            "model_v4.h5",
            self.pretrained_model,
        )
        download_unlicensed_file(
            "" "/models/ner/",
            "model_info_v4.dat",
            self.pretrained_model_info,
        )
        print("Done.")

    def load_model(self):
        """
        Load the pre-trained model and its vocabularies into this instance.

        Sets ``self.model``, ``self.word_vocab``, ``self.y_vocab`` and
        ``self.char_vocab``.
        """
        self.model = NERCRF()
        self.model.load(self.pretrained_model)
        # NOTE(review): pickle.load can execute arbitrary code while
        # unpickling; this is only safe as long as the model info file comes
        # from a trusted source.
        with open(self.pretrained_model_info, "rb") as fp:
            model_info = pickle.load(fp)
        self.word_vocab = model_info["word_vocab"]
        # The stored label vocabulary is inverted so that the ids predicted in
        # inference() can be looked up directly.
        self.y_vocab = {v: k for k, v in model_info["y_vocab"].items()}
        self.char_vocab = model_info["char_vocab"]

    @staticmethod
    def pretty_print(text, tags):
        """
        Build (and print) the annotated-document structure for a tagged text.

        Args:
            text (list): tokens of the document.
            tags (list): one tag per token, passed to ``bio_to_spans``.

        Returns:
            dict: ``{"doc": {...}, "type": "high_level"}`` where ``doc`` holds
            the space-joined text, the distinct lower-cased span types (in
            first-seen order) and the list of spans.
        """
        spans = [
            {"start": start, "end": end, "type": tag}
            for start, end, tag in bio_to_spans(text, tags)
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        annotation_set = list(dict.fromkeys(span["type"].lower() for span in spans))
        ret = {
            "doc_text": " ".join(text),
            "annotation_set": annotation_set,
            "spans": spans,
            "title": "None",
        }
        result = {"doc": ret, "type": "high_level"}
        print(result)
        return result

    def process_text(self, text):
        """
        Normalize whitespace in ``text`` and tokenize it.

        Args:
            text (str): raw input text.

        Returns:
            The result of ``self.nlp.tokenize`` on the text with every run of
            whitespace collapsed to a single space.
        """
        input_text = " ".join(text.strip().split())
        return self.nlp.tokenize(input_text)

    def vectorize(self, doc, vocab, char_vocab):
        """
        Convert tokens into word-id and character-id arrays.

        Words (lower-cased) and characters that are missing from their
        vocabulary are mapped to id ``1``.

        Args:
            doc (list): tokens of the document.
            vocab (dict): word -> id mapping.
            char_vocab (dict): character -> id mapping.

        Returns:
            tuple: ``(words, sentence_chars)`` where ``words`` is reshaped to
            ``(1, -1)`` and ``sentence_chars`` is the output of
            ``pad_sentences(..., self.model.word_length)`` with a new leading
            axis.
        """
        words = np.asarray([vocab.get(w.lower(), 1) for w in doc]).reshape(1, -1)
        sentence_chars = [[char_vocab.get(c, 1) for c in w] for w in doc]
        sentence_chars = np.expand_dims(
            pad_sentences(sentence_chars, self.model.word_length), axis=0
        )
        return words, sentence_chars

    def inference(self, doc):
        """
        Tag a raw text with the loaded model.

        ``load_model()`` must have been called first.

        Args:
            doc (str): raw input text.

        Returns:
            dict: the structure produced by ``pretty_print`` for the tokens
            and their predicted tags.
        """
        text_arr = self.process_text(doc)
        doc_vec = self.vectorize(text_arr, self.word_vocab, self.char_vocab)
        seq_len = np.array([len(text_arr)]).reshape(-1, 1)
        # pylint: disable=no-member
        inputs = list(doc_vec) + [seq_len]
        doc_ner = self.model.predict(inputs, batch_size=1).argmax(2).flatten()
        # Ids missing from the label vocabulary become None.
        tags = [self.y_vocab.get(n) for n in doc_ner]
        return self.pretty_print(text_arr, tags)